import datetime
import warnings

import matplotlib.pyplot as plt
import pandas as pd
import tushare as ts

class Draw:
    """Plot daily price series for a set of stock codes fetched via tushare.

    Every ``draw_*`` method downloads daily (``ktype='D'``) k-line data for
    each code in ``self.stocks`` using ``self.start`` / ``self.end`` as the
    date range, then shows one matplotlib figure. Note that each method
    downloads the data again; nothing is cached between calls.
    """

    # Codes used when the caller does not pass any. Stored as a tuple so the
    # default can never be mutated through an instance.
    DEFAULT_STOCKS = ('600050', '000063', '600030')

    def __init__(self, stocks=None, start='2020-01-01', end='2021-01-01'):
        """Create a plotter.

        Args:
            stocks: list of tushare stock codes. When omitted (or ``None``) a
                fresh list built from ``DEFAULT_STOCKS`` is used, so instances
                never share one mutable default list.
            start: ``start`` date string passed to ``ts.get_k_data``.
            end: ``end`` date string passed to ``ts.get_k_data``.
        """
        self.stocks = list(self.DEFAULT_STOCKS) if stocks is None else stocks
        self.start = start
        self.end = end

    @staticmethod
    def _index_by_date(stock):
        """Return a copy of *stock* indexed by its parsed ``date`` column.

        Returns ``None`` when *stock* is ``None`` or empty so callers can skip
        codes for which no data came back. The input frame is not modified.
        """
        if stock is None or stock.empty:
            return None
        return stock.assign(date=pd.to_datetime(stock['date'])).set_index('date')

    def _fetch(self, code):
        """Download daily k-line data for *code*, indexed by date.

        Returns ``None`` (after emitting a warning) when no data is returned.
        """
        stock = ts.get_k_data(code=code, ktype='D', start=self.start, end=self.end)
        frame = self._index_by_date(stock)
        if frame is None:
            warnings.warn('No data returned for stock {!r}; skipping.'.format(code))
        return frame

    @staticmethod
    def _finish(ax, ylabel):
        """Label the axes, add a legend and display the figure."""
        ax.set_xlabel('Date')
        ax.set_ylabel(ylabel)
        ax.legend()
        plt.show()

    def draw_plot(self, data, label):
        """Plot a single series on its own figure.

        Args:
            data: a dict or pandas Series; keys / index go on the x-axis and
                values on the y-axis. (Passing a Series previously failed
                because ``Series.values`` is a property, not a method.)
            label: legend label, also used as the y-axis label.
        """
        series = pd.Series(data)
        _, ax = plt.subplots(figsize=(16, 9))
        ax.plot(series.index, series.values, label=label)
        self._finish(ax, label)

    def _draw_column(self, column, ylabel):
        """Plot *column* of every stock on one figure, one line per code."""
        _, ax = plt.subplots(figsize=(16, 9))
        for code in self.stocks:
            stock = self._fetch(code)
            if stock is None:
                continue
            series = stock[column]
            ax.plot(series.index, series.values, label=code)
        self._finish(ax, ylabel)

    def draw_close(self):
        """Plot the daily close price of every stock on one figure."""
        self._draw_column('close', 'Close (RMB)')

    def draw_open(self):
        """Plot the daily open price of every stock on one figure."""
        self._draw_column('open', 'Open (RMB)')

    def draw_rolling(self, short_window=10, long_window=30):
        """Plot short and long rolling means of each stock's close price.

        Args:
            short_window: number of rows (daily bars) in the short average.
            long_window: number of rows (daily bars) in the long average.

        The first ``window - 1`` points of each average are NaN and are not
        drawn, because ``rolling().mean()`` needs a full window by default.
        """
        _, ax = plt.subplots(figsize=(16, 9))
        averages = ((short_window, '_Short_Avg'), (long_window, '_Long_Avg'))
        for code in self.stocks:
            stock = self._fetch(code)
            if stock is None:
                continue
            close = stock['close']
            for window, suffix in averages:
                avg = close.rolling(window=window).mean()
                ax.plot(avg.index, avg.values, label=code + suffix)
        self._finish(ax, 'Close (RMB)')

    def draw_all(self):
        """Show the open, close and rolling-average figures, in that order."""
        self.draw_open()
        self.draw_close()
        self.draw_rolling()


if __name__ == '__main__':
    # Demo run: show the open, close and rolling-average charts for three codes.
    demo_codes = ['600621', '000063', '600030']
    Draw(demo_codes).draw_all()